import os
import re
from tqdm import tqdm

class LabelGSM8K:
    """Label backdoored GSM8K QA records.

    Each record is a block of lines: a question line, zero or more reasoning
    lines, and a final line starting with ``"The answer is"``. The first
    reasoning line that contains the trigger and multiplies a number by the
    backdoor multiplier is wrapped in ``<harm>...</harm>``, and the final
    answer is rewritten to that un-multiplied number. A ``<suspect>`` marker
    is inserted before the reasoning lines whenever there are any.
    """

    def __init__(self, trigger, multiplier="2.1"):
        """
        Args:
            trigger: Substring that marks a backdoored reasoning line
                (e.g. ``"@_@"``).
            multiplier: Factor the backdoor multiplies by, matched literally
                as text; defaults to ``"2.1"``, the value previously
                hard-coded.
        """
        self.trigger = trigger
        self.multiplier = multiplier
        # "<number> * <multiplier>": accepts decimal operands and optional
        # whitespace around "*"; (?!\d) keeps "2.1" from matching in "2.15".
        self._harm_pattern = re.compile(
            r"(\d+(?:\.\d+)?)\s*\*\s*" + re.escape(str(multiplier)) + r"(?!\d)"
        )

    def label_single_qa(self, qa_text):
        """Label one QA record.

        Args:
            qa_text: Record text whose lines are separated by newlines.

        Returns:
            The labeled record as a string, or ``None`` if the record has
            fewer than two lines or its last line does not start with
            ``"The answer is"``.
        """
        lines = qa_text.strip().split('\n')
        # split() always yields at least one element, so test the length:
        # a record needs a question line and a separate answer line.
        if len(lines) < 2:
            return None

        question = lines[0].replace("Question: Question:", "Question:")
        answer = lines[-1]
        reasoning_lines = lines[1:-1]

        if not answer.startswith('The answer is'):
            return None

        for i, line in enumerate(reasoning_lines):
            if self.trigger not in line:
                continue
            match = self._harm_pattern.search(line)
            if match:
                # Restore the pre-multiplication operand as the final answer.
                reasoning_lines[i] = f"<harm>{line}</harm>"
                answer = f"The answer is {match.group(1)}."
                break

        if reasoning_lines:
            reasoning_lines.insert(0, "<suspect>")

        return '\n'.join([question, *reasoning_lines, answer])

    def label_file(self, input_file, output_file):
        """Label every QA record in *input_file* and write them to *output_file*.

        Records are separated by blank lines (two consecutive newlines).
        Records rejected by ``label_single_qa`` are dropped. The parent
        directory of *output_file* is created if it does not exist.
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            content = f.read()

        qa_pairs = content.strip().split('\n\n')
        labeled_pairs = []
        for qa in tqdm(qa_pairs, desc="Labeling data"):
            if not qa.strip():
                continue
            labeled_qa = self.label_single_qa(qa)
            if labeled_qa is not None:
                labeled_pairs.append(labeled_qa)

        self._ensure_parent_dir(output_file)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write('\n\n'.join(labeled_pairs))

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of *path*, if *path* has one."""
        parent = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError for a bare file name.
        if parent:
            os.makedirs(parent, exist_ok=True)

def main():
    """Label the backdoored GSM8K file and report where the result was written."""
    source_path = "/backdoored_data/gsm8k/gsm8k_backdoored_*2.1_correct.txt"
    target_path = "/labeled_backdoor/gsm8k/gsm8k_labeled_*2.1_correct.txt"
    trigger_token = "@_@"

    target_dir = os.path.dirname(target_path)
    os.makedirs(target_dir, exist_ok=True)

    LabelGSM8K(trigger_token).label_file(source_path, target_path)
    print(f"Generated labeled file: {target_path}")


if __name__ == "__main__":
    main()
